{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "This notebook shows how to use TimesFM with finetuning. \n",
    "\n",
    "In order to perform finetuning, you need to create the Pytorch Dataset in a proper format. The example of the Dataset is provided below.\n",
    "The finetuning code can be found in timesfm.finetuning_torch.py. This notebook just imports the methods from finetuning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.\n",
      "Loaded Jax TimesFM.\n",
      "Loaded PyTorch TimesFM.\n"
     ]
    }
   ],
   "source": [
    "from os import path\n",
    "from typing import Optional, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.multiprocessing as mp\n",
    "import yfinance as yf\n",
    "from finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner\n",
    "from huggingface_hub import snapshot_download\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams\n",
    "from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder\n",
    "import os\n",
    "\n",
    "\n",
    "class TimeSeriesDataset(Dataset):\n",
    "  \"\"\"Dataset for time series data compatible with TimesFM.\"\"\"\n",
    "\n",
    "  def __init__(self,\n",
    "               series: np.ndarray,\n",
    "               context_length: int,\n",
    "               horizon_length: int,\n",
    "               freq_type: int = 0):\n",
    "    \"\"\"\n",
    "        Initialize dataset.\n",
    "\n",
    "        Args:\n",
    "            series: Time series data\n",
    "            context_length: Number of past timesteps to use as input\n",
    "            horizon_length: Number of future timesteps to predict\n",
    "            freq_type: Frequency type (0, 1, or 2)\n",
    "        \"\"\"\n",
    "    if freq_type not in [0, 1, 2]:\n",
    "      raise ValueError(\"freq_type must be 0, 1, or 2\")\n",
    "\n",
    "    self.series = series\n",
    "    self.context_length = context_length\n",
    "    self.horizon_length = horizon_length\n",
    "    self.freq_type = freq_type\n",
    "    self._prepare_samples()\n",
    "\n",
    "  def _prepare_samples(self) -> None:\n",
    "    \"\"\"Prepare sliding window samples from the time series.\"\"\"\n",
    "    self.samples = []\n",
    "    total_length = self.context_length + self.horizon_length\n",
    "\n",
    "    for start_idx in range(0, len(self.series) - total_length + 1):\n",
    "      end_idx = start_idx + self.context_length\n",
    "      x_context = self.series[start_idx:end_idx]\n",
    "      x_future = self.series[end_idx:end_idx + self.horizon_length]\n",
    "      self.samples.append((x_context, x_future))\n",
    "\n",
    "  def __len__(self) -> int:\n",
    "    return len(self.samples)\n",
    "\n",
    "  def __getitem__(\n",
    "      self, index: int\n",
    "  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    x_context, x_future = self.samples[index]\n",
    "\n",
    "    x_context = torch.tensor(x_context, dtype=torch.float32)\n",
    "    x_future = torch.tensor(x_future, dtype=torch.float32)\n",
    "\n",
    "    input_padding = torch.zeros_like(x_context)\n",
    "    freq = torch.tensor([self.freq_type], dtype=torch.long)\n",
    "\n",
    "    return x_context, input_padding, freq, x_future\n",
    "\n",
    "def prepare_datasets(series: np.ndarray,\n",
    "                     context_length: int,\n",
    "                     horizon_length: int,\n",
    "                     freq_type: int = 0,\n",
    "                     train_split: float = 0.8) -> Tuple[Dataset, Dataset]:\n",
    "  \"\"\"\n",
    "    Prepare training and validation datasets from time series data.\n",
    "\n",
    "    Args:\n",
    "        series: Input time series data\n",
    "        context_length: Number of past timesteps to use\n",
    "        horizon_length: Number of future timesteps to predict\n",
    "        freq_type: Frequency type (0, 1, or 2)\n",
    "        train_split: Fraction of data to use for training\n",
    "\n",
    "    Returns:\n",
    "        Tuple of (train_dataset, val_dataset)\n",
    "    \"\"\"\n",
    "  train_size = int(len(series) * train_split)\n",
    "  train_data = series[:train_size]\n",
    "  val_data = series[train_size:]\n",
    "\n",
    "  # Create datasets with specified frequency type\n",
    "  train_dataset = TimeSeriesDataset(train_data,\n",
    "                                    context_length=context_length,\n",
    "                                    horizon_length=horizon_length,\n",
    "                                    freq_type=freq_type)\n",
    "\n",
    "  val_dataset = TimeSeriesDataset(val_data,\n",
    "                                  context_length=context_length,\n",
    "                                  horizon_length=horizon_length,\n",
    "                                  freq_type=freq_type)\n",
    "\n",
    "  return train_dataset, val_dataset\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model(load_weights: bool = False):\n",
    "  device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "  repo_id = \"google/timesfm-2.0-500m-pytorch\"\n",
    "  hparams = TimesFmHparams(\n",
    "      backend=device,\n",
    "      per_core_batch_size=32,\n",
    "      horizon_len=128,\n",
    "      num_layers=50,\n",
    "      use_positional_embedding=False,\n",
    "      context_len=\n",
    "      192,  # Context length can be anything up to 2048 in multiples of 32\n",
    "  )\n",
    "  tfm = TimesFm(hparams=hparams,\n",
    "                checkpoint=TimesFmCheckpoint(huggingface_repo_id=repo_id))\n",
    "\n",
    "  model = PatchedTimeSeriesDecoder(tfm._model_config)\n",
    "  if load_weights:\n",
    "    checkpoint_path = path.join(snapshot_download(repo_id), \"torch_model.ckpt\")\n",
    "    loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)\n",
    "    model.load_state_dict(loaded_checkpoint)\n",
    "  return model, hparams, tfm._model_config\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_predictions(\n",
    "    model: TimesFm,\n",
    "    val_dataset: Dataset,\n",
    "    save_path: Optional[str] = \"predictions.png\",\n",
    ") -> None:\n",
    "  \"\"\"\n",
    "    Plot model predictions against ground truth for a batch of validation data.\n",
    "\n",
    "    Args:\n",
    "      model: Trained TimesFM model\n",
    "      val_dataset: Validation dataset\n",
    "      save_path: Path to save the plot\n",
    "    \"\"\"\n",
    "  import matplotlib.pyplot as plt\n",
    "\n",
    "  model.eval()\n",
    "\n",
    "  x_context, x_padding, freq, x_future = val_dataset[0]\n",
    "  x_context = x_context.unsqueeze(0)  # Add batch dimension\n",
    "  x_padding = x_padding.unsqueeze(0)\n",
    "  freq = freq.unsqueeze(0)\n",
    "  x_future = x_future.unsqueeze(0)\n",
    "\n",
    "  device = next(model.parameters()).device\n",
    "  x_context = x_context.to(device)\n",
    "  x_padding = x_padding.to(device)\n",
    "  freq = freq.to(device)\n",
    "  x_future = x_future.to(device)\n",
    "\n",
    "  with torch.no_grad():\n",
    "    predictions = model(x_context, x_padding.float(), freq)\n",
    "    predictions_mean = predictions[..., 0]  # [B, N, horizon_len]\n",
    "    last_patch_pred = predictions_mean[:, -1, :]  # [B, horizon_len]\n",
    "\n",
    "  context_vals = x_context[0].cpu().numpy()\n",
    "  future_vals = x_future[0].cpu().numpy()\n",
    "  pred_vals = last_patch_pred[0].cpu().numpy()\n",
    "\n",
    "  context_len = len(context_vals)\n",
    "  horizon_len = len(future_vals)\n",
    "\n",
    "  plt.figure(figsize=(12, 6))\n",
    "\n",
    "  plt.plot(range(context_len),\n",
    "           context_vals,\n",
    "           label=\"Historical Data\",\n",
    "           color=\"blue\",\n",
    "           linewidth=2)\n",
    "\n",
    "  plt.plot(\n",
    "      range(context_len, context_len + horizon_len),\n",
    "      future_vals,\n",
    "      label=\"Ground Truth\",\n",
    "      color=\"green\",\n",
    "      linestyle=\"--\",\n",
    "      linewidth=2,\n",
    "  )\n",
    "\n",
    "  plt.plot(range(context_len, context_len + horizon_len),\n",
    "           pred_vals,\n",
    "           label=\"Prediction\",\n",
    "           color=\"red\",\n",
    "           linewidth=2)\n",
    "\n",
    "  plt.xlabel(\"Time Step\")\n",
    "  plt.ylabel(\"Value\")\n",
    "  plt.title(\"TimesFM Predictions vs Ground Truth\")\n",
    "  plt.legend()\n",
    "  plt.grid(True)\n",
    "\n",
    "  if save_path:\n",
    "    plt.savefig(save_path)\n",
    "    print(f\"Plot saved to {save_path}\")\n",
    "\n",
    "  plt.close()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(context_len: int,\n",
    "             horizon_len: int,\n",
    "             freq_type: int = 0) -> Tuple[Dataset, Dataset]:\n",
    "  df = yf.download(\"AAPL\", start=\"2010-01-01\", end=\"2019-01-01\")\n",
    "  time_series = df[\"Close\"].values\n",
    "\n",
    "  train_dataset, val_dataset = prepare_datasets(\n",
    "      series=time_series,\n",
    "      context_length=context_len,\n",
    "      horizon_length=horizon_len,\n",
    "      freq_type=freq_type,\n",
    "      train_split=0.8,\n",
    "  )\n",
    "\n",
    "  print(f\"Created datasets:\")\n",
    "  print(f\"- Training samples: {len(train_dataset)}\")\n",
    "  print(f\"- Validation samples: {len(val_dataset)}\")\n",
    "  print(f\"- Using frequency type: {freq_type}\")\n",
    "  return train_dataset, val_dataset\n",
    "\n",
    "\n",
    "\n",
    "def single_gpu_example():\n",
    "  \"\"\"Basic example of finetuning TimesFM on stock data.\"\"\"\n",
    "  model, hparams, tfm_config = get_model(load_weights=True)\n",
    "  config = FinetuningConfig(batch_size=256,\n",
    "                            num_epochs=5,\n",
    "                            learning_rate=1e-4,\n",
    "                            use_wandb=True,\n",
    "                            freq_type=1,\n",
    "                            log_every_n_steps=10,\n",
    "                            val_check_interval=0.5,\n",
    "                            use_quantile_loss=True)\n",
    "\n",
    "  train_dataset, val_dataset = get_data(128,\n",
    "                                        tfm_config.horizon_len,\n",
    "                                        freq_type=config.freq_type)\n",
    "  finetuner = TimesFMFinetuner(model, config)\n",
    "\n",
    "  print(\"\\nStarting finetuning...\")\n",
    "  results = finetuner.finetune(train_dataset=train_dataset,\n",
    "                               val_dataset=val_dataset)\n",
    "\n",
    "  print(\"\\nFinetuning completed!\")\n",
    "  print(f\"Training history: {len(results['history']['train_loss'])} epochs\")\n",
    "\n",
    "  plot_predictions(\n",
    "      model=model,\n",
    "      val_dataset=val_dataset,\n",
    "      save_path=\"timesfm_predictions.png\",\n",
    "  )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ac84aeda3a1749ae8f30b06859067bb1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d9d8081fc514c6d8601a2e0e63954a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[*********************100%***********************]  1 of 1 completed\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created datasets:\n",
      "- Training samples: 1556\n",
      "- Validation samples: 198\n",
      "- Using frequency type: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmishacamry\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.19.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/chertushkin/forks/timesfm/notebooks/wandb/run-20250217_114343-tjs63ml2</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2' target=\"_blank\">chocolate-eon-50</a></strong> to <a href='https://wandb.ai/mishacamry/timesfm-finetuning' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/mishacamry/timesfm-finetuning' target=\"_blank\">https://wandb.ai/mishacamry/timesfm-finetuning</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2' target=\"_blank\">https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Starting finetuning...\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<br>    <style><br>        .wandb-row {<br>            display: flex;<br>            flex-direction: row;<br>            flex-wrap: wrap;<br>            justify-content: flex-start;<br>            width: 100%;<br>        }<br>        .wandb-col {<br>            display: flex;<br>            flex-direction: column;<br>            flex-basis: 100%;<br>            flex: 1;<br>            padding: 10px;<br>        }<br>    </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▃▅▆█</td></tr><tr><td>learning_rate</td><td>▁▁▁▁▁</td></tr><tr><td>train_loss</td><td>█▃▂▁▁</td></tr><tr><td>val_loss</td><td>█▁▄▁▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>5</td></tr><tr><td>learning_rate</td><td>0.0001</td></tr><tr><td>train_loss</td><td>2.85423</td></tr><tr><td>val_loss</td><td>26.7628</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">chocolate-eon-50</strong> at: <a href='https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2' target=\"_blank\">https://wandb.ai/mishacamry/timesfm-finetuning/runs/tjs63ml2</a><br> View project at: <a href='https://wandb.ai/mishacamry/timesfm-finetuning' target=\"_blank\">https://wandb.ai/mishacamry/timesfm-finetuning</a><br>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20250217_114343-tjs63ml2/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finetuning completed!\n",
      "Training history: 5 epochs\n",
      "Plot saved to timesfm_predictions.png\n"
     ]
    }
   ],
   "source": [
    "single_gpu_example()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "timesfm-DnAbSweh-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
